{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('dataset-bpe.json') as fopen:\n",
    "    data = json.load(fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_X = data['train_X']\n",
    "train_Y = data['train_Y']\n",
    "test_X = data['test_X']\n",
    "test_Y = data['test_Y']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "EOS = 2\n",
    "GO = 1\n",
    "vocab_size = 32000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_Y = [i + [2] for i in train_Y]\n",
    "test_Y = [i + [2] for i in test_Y]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensor2tensor.utils import beam_search\n",
    "\n",
    "def pad_second_dim(x, desired_size):\n",
    "    padding = tf.tile([[[0.0]]], tf.stack([tf.shape(x)[0], desired_size - tf.shape(x)[1], tf.shape(x)[2]], 0))\n",
    "    return tf.concat([x, padding], 1)\n",
    "\n",
    "class Translator:\n",
    "    def __init__(self, size_layer, num_layers, embedded_size, learning_rate):\n",
    "        \n",
    "        def cells(reuse=False):\n",
    "            return tf.nn.rnn_cell.BasicRNNCell(size_layer,reuse=reuse)\n",
    "        \n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        \n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype = tf.int32)\n",
    "        self.Y_seq_len = tf.count_nonzero(self.Y, 1, dtype = tf.int32)\n",
    "        batch_size = tf.shape(self.X)[0]\n",
    "        \n",
    "        embeddings = tf.Variable(tf.random_uniform([vocab_size, embedded_size], -1, 1))\n",
    "        \n",
    "        _, encoder_state = tf.nn.dynamic_rnn(\n",
    "            cell = tf.nn.rnn_cell.MultiRNNCell([cells() for _ in range(num_layers)]), \n",
    "            inputs = tf.nn.embedding_lookup(embeddings, self.X),\n",
    "            sequence_length = self.X_seq_len,\n",
    "            dtype = tf.float32)\n",
    "        main = tf.strided_slice(self.Y, [0, 0], [batch_size, -1], [1, 1])\n",
    "        decoder_input = tf.concat([tf.fill([batch_size, 1], GO), main], 1)\n",
    "        dense = tf.layers.Dense(vocab_size)\n",
    "        decoder_cells = tf.nn.rnn_cell.MultiRNNCell([cells() for _ in range(num_layers)])\n",
    "        \n",
    "        training_helper = tf.contrib.seq2seq.TrainingHelper(\n",
    "                inputs = tf.nn.embedding_lookup(embeddings, decoder_input),\n",
    "                sequence_length = self.Y_seq_len,\n",
    "                time_major = False)\n",
    "        training_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                cell = decoder_cells,\n",
    "                helper = training_helper,\n",
    "                initial_state = encoder_state,\n",
    "                output_layer = dense)\n",
    "        training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = training_decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = tf.reduce_max(self.Y_seq_len))\n",
    "        self.training_logits = training_decoder_output.rnn_output\n",
    "        \n",
    "        predicting_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(\n",
    "                embedding = embeddings,\n",
    "                start_tokens = tf.tile(tf.constant([GO], dtype=tf.int32), [batch_size]),\n",
    "                end_token = EOS)\n",
    "        predicting_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                cell = decoder_cells,\n",
    "                helper = predicting_helper,\n",
    "                initial_state = encoder_state,\n",
    "                output_layer = dense)\n",
    "        predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = predicting_decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = 2 * tf.reduce_max(self.X_seq_len))\n",
    "        self.fast_result = predicting_decoder_output.sample_id\n",
    "        \n",
    "        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)\n",
    "        self.cost = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,\n",
    "                                                     targets = self.Y,\n",
    "                                                     weights = masks)\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(self.cost)\n",
    "        y_t = tf.argmax(self.training_logits,axis=2)\n",
    "        y_t = tf.cast(y_t, tf.int32)\n",
    "        self.prediction = tf.boolean_mask(y_t, masks)\n",
    "        mask_label = tf.boolean_mask(self.Y, masks)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        correct_index = tf.cast(correct_pred, tf.float32)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "size_layer = 512\n",
    "num_layers = 2\n",
    "embedded_size = 256\n",
    "learning_rate = 1e-3\n",
    "batch_size = 128\n",
    "epoch = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/client/session.py:1750: UserWarning: An interactive session is already active. This can cause out-of-memory errors in some cases. You must explicitly call `InteractiveSession.close()` to release resources held by the other session(s).\n",
      "  warnings.warn('An interactive session is already active. This can '\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Translator(size_layer, num_layers, embedded_size, learning_rate)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "pad_sequences = tf.keras.preprocessing.sequence.pad_sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([[ 8092,   364, 23527, 16731, 12432, 24937, 21081, 21142, 14804,\n",
       "         18234, 25024, 14908, 27933, 23633,  8088, 26961, 10696,  8350,\n",
       "          2336,  1560,  9475,  7028,  1952,  9737, 27888, 27603, 21554,\n",
       "          5376, 30761, 24453, 14154, 25407, 13988, 11466, 27134, 17576,\n",
       "         10293,   435, 28450, 28138, 31434,  2669,  9231, 21043, 29167,\n",
       "          3865, 12123, 26151, 22312, 20040, 16020,  7213, 14383,  6306,\n",
       "         17745, 20872, 21499, 20713, 27365,  4323, 30281,   647, 23627,\n",
       "         10903, 12831, 24343,   869, 27604, 15795, 15732, 28700,  7564],\n",
       "        [13267, 25553,   398, 27889,  9162, 18628, 28777,  7077, 25680,\n",
       "         10341, 31146, 16193, 15781,  5684, 20317,  1807,  3450,  9648,\n",
       "         19283, 18304, 14013, 11349,  8453,  3195,   463, 20584,  8176,\n",
       "         19834, 13095, 24631, 26027,  5502,  2789, 13208, 19939, 28314,\n",
       "          7387, 16891, 23451, 19947, 17939,  7440,  4705,  5147, 29115,\n",
       "          9981, 21912,  9427, 19677, 30367, 11677, 18783, 29427, 22706,\n",
       "         27585, 26323, 22117, 12203, 12237, 17905, 25664, 28479, 23885,\n",
       "          8847, 29361, 13901, 25015, 30020,  6621, 18262, 20044,   793],\n",
       "        [25888, 13104, 19073, 13882, 25347, 28031, 16980, 13663, 30630,\n",
       "         21524, 22420, 20476, 19504,  3087, 18376, 20286, 19796, 21386,\n",
       "         11998, 14303, 11230, 31254,  9698,  2441, 18495,   563, 20120,\n",
       "          4043, 26373,  2561, 13558, 21822, 11268, 20936, 17852, 30478,\n",
       "         15108, 23839, 17263,  1781, 10493, 27366, 28273, 19574, 22308,\n",
       "         12773, 15046, 21778,  5379, 17279, 17848, 16524,  3791,  9124,\n",
       "          1470, 11228, 20237, 10543,  1773, 30716,  4949, 23964, 30918,\n",
       "          1606, 15265,  2073,  1912,  2732,  3298,  6347,  4863, 16785],\n",
       "        [ 3953, 19172,  7860,  7226, 15206, 31506, 13581, 12356,  2424,\n",
       "         29176, 28383,  1964, 15063,  3373, 30104, 22827, 26969, 31149,\n",
       "          8929, 21050, 17803, 26325, 20349,   170, 21349,  9271,  6938,\n",
       "          1115, 28638, 27742,  6784, 22036, 24586, 25323,  1397,  1649,\n",
       "         21863,  2279, 22802,  4233,  4134, 29405,  4721, 17180, 11102,\n",
       "         31819, 28859,   476, 23177, 29408, 24049,  5026, 13457,  2267,\n",
       "         23273, 18454, 14555, 19882, 15548, 19311, 27900, 31037, 24371,\n",
       "         12639,  4716, 23128, 18700, 13460,  1223, 17807, 14073,  6324],\n",
       "        [28073, 14887,  8094, 29888, 30871, 22065,  3576, 10840, 15791,\n",
       "         25776, 18585, 12696,  3850, 24351,  2267,  2993, 13708,  8596,\n",
       "         22762, 12654,  7751, 27027, 29957, 26241, 20083,  4850, 10905,\n",
       "         15395, 17023, 26495, 19274, 15869, 19036, 27350, 14358, 21701,\n",
       "         12954, 24876,  8412, 19410, 18644,  4436,  8881, 28932, 28105,\n",
       "          9048, 20711, 25427, 26394, 19509, 26426,  5764, 12757, 25558,\n",
       "         19141, 13606, 31124, 11529, 12995,  6525, 18384, 30323, 21503,\n",
       "          4762,  1144, 25314, 25638, 17143,  7943, 22824, 22665, 14542],\n",
       "        [ 8734, 14774, 15053, 22977, 11008, 14026, 15342,  3554,   854,\n",
       "         17472, 23770, 15259, 26244, 11156, 16844,  1175, 26715,  9180,\n",
       "         15703, 19322, 30378,  4798,  4249, 16533, 12248, 15761,  3797,\n",
       "         23640, 21332, 26114, 24196,  8412, 25555,  8806, 17195, 15741,\n",
       "         11501,  4463, 18720,  6523, 30750, 16390, 25409, 19032,  1192,\n",
       "         20408, 10301, 31946, 29912,  2931,  9093, 10539,  4648, 31751,\n",
       "         24813, 24842,  8402,  3866, 28745, 26607,  8885, 26869, 13440,\n",
       "          2361, 10348, 23461, 17655,  9538, 24317, 24002, 23711, 19507],\n",
       "        [20681, 22178,  6759, 26182, 23603, 14513, 30301,  5438, 17831,\n",
       "          9621, 25190, 19349, 11788, 24768,  5845,  6541, 27546,  2919,\n",
       "         19595,   955,  8535, 12929, 20763, 29832, 21078,  2328, 10863,\n",
       "         17892,  5082,  7884, 21420, 22107,  3242, 16307, 28868, 31800,\n",
       "         14964, 13342,  7417,  7730,  4597, 31800, 13006,  7866,  4688,\n",
       "          6265,  8481, 12363,  9197, 14503, 18132, 17563, 25826,  6762,\n",
       "         19442, 15642,  3270, 21740, 28046, 17598, 29722,  5683, 11687,\n",
       "         21518,  1491, 29782, 13832,  3291, 11776, 19105,  9432, 24817],\n",
       "        [ 2281, 30813, 29142, 31103, 17059, 27674, 24746,  2753, 14259,\n",
       "         13162, 27555, 20389, 30173, 16573,  6435,   886,  7047, 13766,\n",
       "         25416, 23059,  2787,  3705, 26428, 14210, 11678, 30209, 11519,\n",
       "         14181, 10191,  3713, 26011, 22138, 28427, 27298,  8008,  6611,\n",
       "          4927, 12607, 31287, 26706, 10243, 11705,  2863, 31464, 12841,\n",
       "           398,  8985,  6972,  8573,  4230, 20879,  6163, 13199, 19599,\n",
       "         11855, 18812, 13303,  2368, 31514, 13648, 28279, 14511, 19608,\n",
       "          9503, 11494,  9560, 14941, 31090,  5664,  1005,  6882,  7334],\n",
       "        [21642, 28651,  4239, 31270, 22920,  2733,  2614, 20510, 26668,\n",
       "           596,  2237, 26641, 23547, 27697, 17258, 18297,  7523, 22222,\n",
       "         23671, 13238,  8692, 27458,  6950,  6392, 22839, 29692,  6827,\n",
       "          1923, 22292,  4563,   638, 24575,   702, 26437, 14252, 10517,\n",
       "          2329, 11463, 31996, 26343,  4543, 14744,  4860,   171, 19283,\n",
       "         29326, 17165, 14221, 14317, 17032, 21910,  4096, 30839, 14664,\n",
       "         13125, 12924, 29338, 22510,  2294, 26486, 14079,  5307, 23237,\n",
       "          5674, 12073, 11821, 11683, 10327,  5611, 20650, 27570, 30479],\n",
       "        [29589, 23008, 21286, 14232, 19157,  3022, 25141, 13909, 31895,\n",
       "         30616,   520,  3764,  8229,  9508, 29627, 13615, 16755, 20234,\n",
       "         28629,  4910, 10158,  1790, 23695, 21559, 15206, 23256,  3946,\n",
       "         14038, 11572, 21260, 10516, 30856, 11683, 27579,  6546, 29923,\n",
       "         13803, 25248, 28312, 12703, 20991, 14064,  8449, 19784,  2763,\n",
       "         25093,  5278, 17380,  6877, 30872, 25644, 27530,  5011,  4991,\n",
       "         24005,  7766, 15518, 20865, 29947, 24200, 28585, 10705, 11637,\n",
       "          4306, 24922, 25641, 11231, 19239,  9249, 21678, 29688,  3984]],\n",
       "       dtype=int32), 10.374851, 0.0]"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_x = pad_sequences(train_X[:10], padding='post')\n",
    "batch_y = pad_sequences(train_Y[:10], padding='post')\n",
    "\n",
    "sess.run([model.fast_result, model.cost, model.accuracy], \n",
    "         feed_dict = {model.X: batch_x, model.Y: batch_y})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:34<00:00,  3.96it/s, accuracy=0.149, cost=5.87]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  6.99it/s, accuracy=0.161, cost=5.22]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, training avg loss 6.435658, training avg acc 0.124611\n",
      "epoch 1, testing avg loss 5.722437, testing avg acc 0.162499\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:33<00:00,  3.98it/s, accuracy=0.182, cost=5.46]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.30it/s, accuracy=0.167, cost=4.85]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2, training avg loss 5.501996, training avg acc 0.174138\n",
      "epoch 2, testing avg loss 5.348543, testing avg acc 0.181107\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:53<00:00,  3.78it/s, accuracy=0.183, cost=5.25]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.36it/s, accuracy=0.226, cost=4.72]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3, training avg loss 5.225808, training avg acc 0.188166\n",
      "epoch 3, testing avg loss 5.185829, testing avg acc 0.189879\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:55<00:00,  3.77it/s, accuracy=0.197, cost=5.06]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  6.92it/s, accuracy=0.231, cost=4.58]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 4, training avg loss 5.058506, training avg acc 0.198194\n",
      "epoch 4, testing avg loss 5.081503, testing avg acc 0.196356\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:40<00:00,  3.90it/s, accuracy=0.208, cost=4.92]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  6.87it/s, accuracy=0.22, cost=4.57] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 5, training avg loss 4.939404, training avg acc 0.206569\n",
      "epoch 5, testing avg loss 4.986885, testing avg acc 0.204661\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:40<00:00,  3.90it/s, accuracy=0.226, cost=4.82]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  6.91it/s, accuracy=0.237, cost=4.57]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 6, training avg loss 4.844442, training avg acc 0.213995\n",
      "epoch 6, testing avg loss 4.951038, testing avg acc 0.208642\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:40<00:00,  3.90it/s, accuracy=0.235, cost=4.74]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  6.89it/s, accuracy=0.22, cost=4.54] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 7, training avg loss 4.767762, training avg acc 0.220543\n",
      "epoch 7, testing avg loss 4.901244, testing avg acc 0.212730\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:40<00:00,  3.90it/s, accuracy=0.234, cost=4.67]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  6.89it/s, accuracy=0.226, cost=4.55]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 8, training avg loss 4.699906, training avg acc 0.225791\n",
      "epoch 8, testing avg loss 4.878349, testing avg acc 0.215059\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:40<00:00,  3.90it/s, accuracy=0.244, cost=4.61]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  6.94it/s, accuracy=0.231, cost=4.46]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 9, training avg loss 4.651851, training avg acc 0.229174\n",
      "epoch 9, testing avg loss 4.843753, testing avg acc 0.218257\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:40<00:00,  3.90it/s, accuracy=0.252, cost=4.53]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  6.91it/s, accuracy=0.242, cost=4.49]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 10, training avg loss 4.612467, training avg acc 0.232239\n",
      "epoch 10, testing avg loss 4.829657, testing avg acc 0.220272\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:40<00:00,  3.90it/s, accuracy=0.256, cost=4.48]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  6.91it/s, accuracy=0.237, cost=4.45]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 11, training avg loss 4.579409, training avg acc 0.234620\n",
      "epoch 11, testing avg loss 4.824563, testing avg acc 0.221824\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:41<00:00,  3.90it/s, accuracy=0.249, cost=4.46]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  6.91it/s, accuracy=0.242, cost=4.46]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 12, training avg loss 4.544057, training avg acc 0.237470\n",
      "epoch 12, testing avg loss 4.818059, testing avg acc 0.222746\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:40<00:00,  3.90it/s, accuracy=0.259, cost=4.37]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.29it/s, accuracy=0.247, cost=4.45]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 13, training avg loss 4.518851, training avg acc 0.239499\n",
      "epoch 13, testing avg loss 4.801131, testing avg acc 0.225936\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:35<00:00,  3.95it/s, accuracy=0.258, cost=4.39]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.12it/s, accuracy=0.242, cost=4.45]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 14, training avg loss 4.487974, training avg acc 0.242044\n",
      "epoch 14, testing avg loss 4.815407, testing avg acc 0.222855\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:37<00:00,  3.94it/s, accuracy=0.267, cost=4.35]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.15it/s, accuracy=0.231, cost=4.48]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 15, training avg loss 4.474396, training avg acc 0.243263\n",
      "epoch 15, testing avg loss 4.852764, testing avg acc 0.220577\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:57<00:00,  3.75it/s, accuracy=0.27, cost=4.28] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.11it/s, accuracy=0.253, cost=4.42]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 16, training avg loss 4.440702, training avg acc 0.246357\n",
      "epoch 16, testing avg loss 4.798965, testing avg acc 0.225011\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:36<00:00,  3.94it/s, accuracy=0.265, cost=4.25]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.11it/s, accuracy=0.237, cost=4.46]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 17, training avg loss 4.406374, training avg acc 0.249040\n",
      "epoch 17, testing avg loss 4.798480, testing avg acc 0.225326\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:36<00:00,  3.94it/s, accuracy=0.271, cost=4.27]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.20it/s, accuracy=0.253, cost=4.48]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 18, training avg loss 4.431464, training avg acc 0.246378\n",
      "epoch 18, testing avg loss 4.807458, testing avg acc 0.226043\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:57<00:00,  3.75it/s, accuracy=0.273, cost=4.19]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.12it/s, accuracy=0.226, cost=4.49]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 19, training avg loss 4.376853, training avg acc 0.251078\n",
      "epoch 19, testing avg loss 4.796037, testing avg acc 0.226825\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [06:36<00:00,  3.94it/s, accuracy=0.253, cost=4.28]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:05<00:00,  7.13it/s, accuracy=0.242, cost=4.35]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 20, training avg loss 4.349658, training avg acc 0.253415\n",
      "epoch 20, testing avg loss 4.827734, testing avg acc 0.224312\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import tqdm\n",
    "\n",
    "for e in range(epoch):\n",
    "    pbar = tqdm.tqdm(\n",
    "        range(0, len(train_X), batch_size), desc = 'minibatch loop')\n",
    "    train_loss, train_acc, test_loss, test_acc = [], [], [], []\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(train_X))\n",
    "        batch_x = pad_sequences(train_X[i : index], padding='post')\n",
    "        batch_y = pad_sequences(train_Y[i : index], padding='post')\n",
    "        feed = {model.X: batch_x,\n",
    "                model.Y: batch_y}\n",
    "        accuracy, loss, _ = sess.run([model.accuracy,model.cost,model.optimizer],\n",
    "                                    feed_dict = feed)\n",
    "        train_loss.append(loss)\n",
    "        train_acc.append(accuracy)\n",
    "        pbar.set_postfix(cost = loss, accuracy = accuracy)\n",
    "    \n",
    "    \n",
    "    pbar = tqdm.tqdm(\n",
    "        range(0, len(test_X), batch_size), desc = 'minibatch loop')\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(test_X))\n",
    "        batch_x = pad_sequences(test_X[i : index], padding='post')\n",
    "        batch_y = pad_sequences(test_Y[i : index], padding='post')\n",
    "        feed = {model.X: batch_x,\n",
    "                model.Y: batch_y,}\n",
    "        accuracy, loss = sess.run([model.accuracy,model.cost],\n",
    "                                    feed_dict = feed)\n",
    "\n",
    "        test_loss.append(loss)\n",
    "        test_acc.append(accuracy)\n",
    "        pbar.set_postfix(cost = loss, accuracy = accuracy)\n",
    "    \n",
    "    print('epoch %d, training avg loss %f, training avg acc %f'%(e+1,\n",
    "                                                                 np.mean(train_loss),np.mean(train_acc)))\n",
    "    print('epoch %d, testing avg loss %f, testing avg acc %f'%(e+1,\n",
    "                                                              np.mean(test_loss),np.mean(test_acc)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 40/40 [00:23<00:00,  1.68it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.005418866"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tensor2tensor.utils import bleu_hook\n",
    "\n",
    "results = []\n",
    "for i in tqdm.tqdm(range(0, len(test_X), batch_size)):\n",
    "    index = min(i + batch_size, len(test_X))\n",
    "    batch_x = pad_sequences(test_X[i : index], padding='post')\n",
    "    feed = {model.X: batch_x}\n",
    "    p = sess.run(model.fast_result,feed_dict = feed)\n",
    "    result = []\n",
    "    for row in p:\n",
    "        result.append([i for i in row if i > 3])\n",
    "    results.extend(result)\n",
    "    \n",
    "rights = []\n",
    "for r in test_Y:\n",
    "    rights.append([i for i in r if i > 3])\n",
    "    \n",
    "bleu_hook.compute_bleu(reference_corpus = rights,\n",
    "                       translation_corpus = results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
